package cn.myj2c;

import cn.myj2c.asm.ConstantPool;
import cn.myj2c.executor.Context;
import cn.myj2c.executor.MethodExecutor;
import cn.myj2c.executor.defined.JVMComparisonProvider;
import cn.myj2c.executor.defined.JVMMethodProvider;
import cn.myj2c.executor.defined.MappedMethodProvider;
import cn.myj2c.executor.providers.ComparisonProvider;
import cn.myj2c.executor.providers.DelegatingProvider;
import cn.myj2c.executor.values.JavaValue;
import cn.myj2c.utils.InstructionModifier;
import org.apache.commons.io.IOUtils;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.JSRInlinerAdapter;
import org.objectweb.asm.tree.*;
import org.objectweb.asm.tree.analysis.*;
import org.objectweb.asm.util.CheckClassAdapter;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;

import static org.objectweb.asm.Opcodes.*;

/**
 * Static string decryptor for jars obfuscated with Allatori.
 *
 * <p>Workflow of {@link #main}: load every {@code .class} entry of the input jar, find
 * {@code LDC "<cipher>"; INVOKESTATIC owner.m(String|Object)String} call sites whose target
 * looks like an Allatori decryption routine, emulate that routine with {@link MethodExecutor},
 * replace the call site with the decrypted constant, delete decryptor methods that are no longer
 * referenced by any {@code INVOKESTATIC}, and write the classes to {@code <input>.out.jar}.
 * Only class files are written to the output jar; other entries of the input jar are not copied.
 */
public class AllatoriStringDecrypt {
    /** Every class read from the input jar, keyed by internal name. */
    private static final Map<String, ClassNode> classes = new HashMap<>();
    /** Constant pool of each loaded class, read from the original class bytes. */
    private final static Map<ClassNode, ConstantPool> constantPools = new HashMap<>();
    /** Input jar used when no path is given on the command line. */
    private static final String DEFAULT_INPUT = "D:\\work\\allatori-deobfuscator\\obfuscator\\allatoriDeobfuscator.jar";

    /**
     * Entry point.
     *
     * @param args optional; {@code args[0]} is the path of the jar to process
     *             (defaults to {@link #DEFAULT_INPUT})
     * @throws IOException if the input jar cannot be read or the output jar cannot be created
     */
    public static void main(String[] args) throws IOException {
        String file = args != null && args.length > 0 ? args[0] : DEFAULT_INPUT;
        // Pre-load every class of the input jar.
        loadInput(new File(file));
        DelegatingProvider provider = createProvider();
        Set<MethodNode> decryptor = new HashSet<>();

        // Decrypt obfuscated string constants.
        for (ClassNode classNode : classes.values()) {
            for (MethodNode method : classNode.methods) {
                decryptStrings(classNode, method, provider, decryptor);
            }
        }
        int removed = cleanup(decryptor);
        System.out.println("Removed " + removed + " unused decryptor method(s)");
        writeOutput(outputPathFor(file));
    }

    /**
     * Builds the provider chain used by {@link MethodExecutor} when emulating decryptors.
     * The last provider is a fallback: every checkcast succeeds, while instanceof and
     * equality checks are declined ({@code canCheck*} returns false).
     */
    private static DelegatingProvider createProvider() {
        DelegatingProvider provider = new DelegatingProvider();
        provider.register(new JVMMethodProvider());
        provider.register(new JVMComparisonProvider());
        provider.register(new MappedMethodProvider(classes));
        provider.register(new ComparisonProvider() {
            @Override
            public boolean instanceOf(JavaValue target, Type type, Context context) {
                return false;
            }

            @Override
            public boolean checkcast(JavaValue target, Type type, Context context) {
                return true;
            }

            @Override
            public boolean checkEquality(JavaValue first, JavaValue second, Context context) {
                return false;
            }

            @Override
            public boolean canCheckInstanceOf(JavaValue target, Type type, Context context) {
                return false;
            }

            @Override
            public boolean canCheckcast(JavaValue target, Type type, Context context) {
                return true;
            }

            @Override
            public boolean canCheckEquality(JavaValue first, JavaValue second, Context context) {
                return false;
            }
        });
        return provider;
    }

    /**
     * Replaces every {@code LDC "<cipher>"; INVOKESTATIC decryptor} pair in {@code method} with
     * the decrypted constant (the LDC's value is overwritten and the call is removed).
     *
     * @param classNode class declaring {@code method}
     * @param method    method whose instructions are rewritten in place
     * @param provider  provider chain for the emulator
     * @param decryptor methods already identified as decryptors; each successful decryptor is added
     */
    private static void decryptStrings(ClassNode classNode, MethodNode method, DelegatingProvider provider,
                                       Set<MethodNode> decryptor) {
        Frame<SourceValue>[] frames;
        try {
            frames = new Analyzer<>(new SourceInterpreter()).analyze(classNode.name, method);
        } catch (AnalyzerException e) {
            // Methods the analyzer rejects are left untouched.
            return;
        }
        // Removals are deferred so the instruction list (and the frame indices) stay stable while scanning.
        InstructionModifier modifier = new InstructionModifier();
        for (AbstractInsnNode node = method.instructions.getFirst(); node != null; node = node.getNext()) {
            // Pseudo-instructions (labels, line numbers, frames) have opcode -1 and are skipped here too.
            if (node.getOpcode() != Opcodes.INVOKESTATIC) {
                continue;
            }
            MethodInsnNode m = (MethodInsnNode) node;
            if (!m.desc.equals("(Ljava/lang/Object;)Ljava/lang/String;") && !m.desc.equals("(Ljava/lang/String;)Ljava/lang/String;")) {
                continue;
            }

            Frame<SourceValue> f = frames[method.instructions.indexOf(m)];
            // The analyzer leaves no frame for unreachable instructions.
            if (f == null) {
                continue;
            }
            // The single argument must come from exactly one instruction: a string LDC.
            SourceValue argument = f.getStack(f.getStackSize() - 1);
            if (argument.insns.size() != 1) {
                continue;
            }
            AbstractInsnNode insn = argument.insns.iterator().next();
            if (insn.getOpcode() != Opcodes.LDC) {
                continue;
            }
            LdcInsnNode ldc = (LdcInsnNode) insn;
            if (!(ldc.cst instanceof String)) {
                continue;
            }

            ClassNode targetClassNode = classes.get(m.owner);
            if (targetClassNode == null) {
                continue;
            }
            MethodNode decrypterNode = findMethod(targetClassNode, m.name, m.desc);
            if (decrypterNode == null || decrypterNode.instructions.getFirst() == null) {
                continue;
            }
            if (!decryptor.contains(decrypterNode) && !isAllatoriMethod(decrypterNode)) {
                continue;
            }

            System.out.println("找到加密字符串 类：" + targetClassNode.name + "方法：" + decrypterNode.name + "加密字符串：" + ldc.cst);

            Context context = new Context(provider);
            context.push(classNode.name, method.name, getConstantPool(classNode).getSize());
            patchMethod(decrypterNode);
            try {
                Object decrypted = MethodExecutor.execute(targetClassNode, decrypterNode,
                        Collections.singletonList(JavaValue.valueOf(ldc.cst)), null, context);
                if (decrypted instanceof String) {
                    ldc.cst = decrypted;
                    modifier.remove(m);
                    decryptor.add(decrypterNode);
                } else {
                    // An LDC must hold a valid constant; leave the call site as it was.
                    System.out.println("Decryptor returned a non-string value (" + decrypted + ") in "
                            + classNode.name + " " + method.name + method.desc);
                }
            } catch (Throwable t) {
                System.out.println("Error while decrypting Allatori string.");
                System.out.println("Are you sure you're deobfuscating something obfuscated by Allatori?");
                System.out.println(classNode.name + " " + method.name + method.desc + " " + m.owner + " " + m.name + m.desc);
                t.printStackTrace(System.out);
            }
        }
        modifier.apply(method);
    }

    /** Returns the method of {@code owner} with exactly this name and descriptor, or {@code null}. */
    private static MethodNode findMethod(ClassNode owner, String name, String desc) {
        for (MethodNode candidate : owner.methods) {
            if (candidate.name.equals(name) && candidate.desc.equals(desc)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Derives the output jar path. A trailing ".jar" becomes ".out.jar"; any other path gets
     * ".out.jar" appended, so the input is never overwritten.
     */
    private static String outputPathFor(String input) {
        String suffix = ".jar";
        if (input.endsWith(suffix)) {
            return input.substring(0, input.length() - suffix.length()) + ".out.jar";
        }
        return input + ".out.jar";
    }

    /**
     * Writes every loaded class to a new jar at {@code output}. Classes that cannot be
     * serialized are skipped (the failure is logged by {@link #toByteArray}).
     */
    private static void writeOutput(String output) throws IOException {
        try (ZipOutputStream zipOut = new ZipOutputStream(new FileOutputStream(output))) {
            for (ClassNode classNode : classes.values()) {
                byte[] b = toByteArray(classNode);
                if (b == null) {
                    continue;
                }
                try {
                    zipOut.putNextEntry(new ZipEntry(classNode.name + ".class"));
                    zipOut.write(b);
                    zipOut.closeEntry();
                } catch (IOException e) {
                    System.out.println("Error writing entry " + classNode.name + ": " + e.getMessage());
                }
            }
        }
    }

    /**
     * Serializes {@code node}, recomputing stack map frames. If frame computation fails, the
     * class is written again with only max stack/locals computed (no stack map frames).
     * The result is re-read through {@link CheckClassAdapter}; verification problems are only logged.
     *
     * @return the class bytes, or {@code null} if the class could not be written at all
     */
    private static byte[] toByteArray(ClassNode node) {
        if (node.innerClasses != null) {
            // Some obfuscators (original note: "Stringer") emit inner names containing '/'; keep the simple name.
            node.innerClasses.stream().filter(in -> in.innerName != null).forEach(in -> {
                if (in.innerName.indexOf('/') != -1) {
                    in.innerName = in.innerName.substring(in.innerName.lastIndexOf('/') + 1);
                }
            });
        }
        byte[] classBytes = writeClass(node, ClassWriter.COMPUTE_FRAMES);
        if (classBytes == null) {
            // Labels can keep state from the aborted write; reset them before reusing the instruction lists.
            for (MethodNode method : node.methods) {
                method.instructions.resetLabels();
            }
            System.out.println("Retrying " + node.name + " without stack map frames");
            classBytes = writeClass(node, ClassWriter.COMPUTE_MAXS);
            if (classBytes == null) {
                return null;
            }
        }

        ClassReader cr = new ClassReader(classBytes);
        try {
            cr.accept(new CheckClassAdapter(new ClassWriter(0)), 0);
        } catch (Throwable t) {
            System.out.println("Error: " + node.name + " failed verification");
            t.printStackTrace(System.out);
        }
        return classBytes;
    }

    /** Writes {@code node} with the given ClassWriter flags; logs and returns {@code null} on any failure. */
    private static byte[] writeClass(ClassNode node, int flags) {
        ClassWriter writer = new HierarchyAwareClassWriter(flags);
        try {
            node.accept(writer);
            return writer.toByteArray();
        } catch (Throwable e) {
            System.out.println("Error while writing " + node.name);
            e.printStackTrace(System.out);
            return null;
        }
    }

    /**
     * ClassWriter that resolves super classes from the loaded input classes first and only then
     * from this tool's class loader. As a result, {@code COMPUTE_FRAMES} does not require the
     * input jar to be on the classpath.
     */
    private static final class HierarchyAwareClassWriter extends ClassWriter {
        HierarchyAwareClassWriter(int flags) {
            super(flags);
        }

        @Override
        protected String getCommonSuperClass(String type1, String type2) {
            if (type1.equals(type2)) {
                return type1;
            }
            // As in ASM's default implementation, interfaces merge to Object.
            if (isInterface(type1) || isInterface(type2)) {
                return "java/lang/Object";
            }
            Set<String> ancestors = new HashSet<>();
            String current = type1;
            // The add() check also stops on (malformed) cyclic hierarchies.
            while (current != null && ancestors.add(current)) {
                current = superNameOf(current);
            }
            Set<String> seen = new HashSet<>();
            current = type2;
            while (current != null && seen.add(current)) {
                if (ancestors.contains(current)) {
                    return current;
                }
                current = superNameOf(current);
            }
            return "java/lang/Object";
        }

        private static boolean isInterface(String type) {
            ClassNode node = classes.get(type);
            if (node != null) {
                return (node.access & ACC_INTERFACE) != 0;
            }
            Class<?> runtime = loadRuntimeClass(type);
            return runtime != null && runtime.isInterface();
        }

        /** Internal name of the super class of {@code type}, or {@code null} if unknown / none. */
        private static String superNameOf(String type) {
            ClassNode node = classes.get(type);
            if (node != null) {
                return node.superName;
            }
            Class<?> runtime = loadRuntimeClass(type);
            if (runtime == null || runtime.getSuperclass() == null) {
                return null;
            }
            return runtime.getSuperclass().getName().replace('.', '/');
        }

        /** Loads (without initializing) a class visible to this tool, or returns {@code null}. */
        private static Class<?> loadRuntimeClass(String type) {
            try {
                return Class.forName(type.replace('/', '.'), false, AllatoriStringDecrypt.class.getClassLoader());
            } catch (ClassNotFoundException | LinkageError e) {
                return null;
            }
        }
    }

    /**
     * Prepares a decryptor for emulation. This applies only when the method calls methods named
     * {@code getStackTrace} and {@code getClassName}. In that case, every {@code NEW} of a type
     * whose name ends in "Exception" or "Error", and every call whose owner ends that way, is
     * retargeted to {@code java/lang/RuntimeException}.
     * NOTE(review): presumably so the emulator only has to model one known throwable type — confirm.
     */
    private static void patchMethod(MethodNode method) {
        boolean getStackTrace = false;
        boolean getClassName = false;
        for (AbstractInsnNode i = method.instructions.getFirst(); i != null; i = i.getNext()) {
            if (!(i instanceof MethodInsnNode)) {
                continue;
            }
            String name = ((MethodInsnNode) i).name;
            if (!getStackTrace && name.equals("getStackTrace")) {
                getStackTrace = true;
                if (getClassName) {
                    break;
                }
            } else if (!getClassName && name.equals("getClassName")) {
                getClassName = true;
                if (getStackTrace) {
                    break;
                }
            }
        }
        if (!getClassName || !getStackTrace) {
            return;
        }
        for (AbstractInsnNode insn = method.instructions.getFirst(); insn != null; insn = insn.getNext()) {
            if (insn.getOpcode() == Opcodes.NEW) {
                TypeInsnNode typeInsn = (TypeInsnNode) insn;
                if (typeInsn.desc.endsWith("Exception") || typeInsn.desc.endsWith("Error")) {
                    typeInsn.desc = "java/lang/RuntimeException";
                }
            } else if (insn instanceof MethodInsnNode) {
                MethodInsnNode methodInsn = (MethodInsnNode) insn;
                if (methodInsn.owner.endsWith("Exception") || methodInsn.owner.endsWith("Error")) {
                    methodInsn.owner = "java/lang/RuntimeException";
                }
            }
        }
    }

    /**
     * Reads every non-directory entry of {@code file} and hands it to {@link #loadInput(String, byte[])}.
     * Entries named "*.class/" report {@code isDirectory()} and are therefore skipped as well.
     */
    private static void loadInput(File file) throws IOException {
        try (ZipFile zipIn = new ZipFile(file)) {
            Enumeration<? extends ZipEntry> entries = zipIn.entries();
            while (entries.hasMoreElements()) {
                ZipEntry next = entries.nextElement();
                if (next.isDirectory() || next.getName().endsWith(".class/")) {
                    continue;
                }
                byte[] data;
                try (InputStream in = zipIn.getInputStream(next)) {
                    data = IOUtils.toByteArray(in);
                }
                loadInput(next.getName(), data);
            }
        }
    }

    /**
     * Parses one jar entry. Entries whose name ends in ".class" or ".class/" and whose data is
     * longer than 30 bytes are parsed with frames skipped. JSR/RET subroutines are inlined, and
     * the class is registered together with its constant pool. Entries that fail to parse are
     * logged and ignored.
     *
     * @param name entry name inside the jar
     * @param data raw entry bytes
     */
    public static void loadInput(String name, byte[] data) {
        if (name.endsWith(".class") || name.endsWith(".class/")) {
            // Too small to be a real class file.
            if (data == null || data.length <= 30) {
                return;
            }
            try {
                ClassReader reader = new ClassReader(data);
                ClassNode node = new ClassNode();
                reader.accept(node, ClassReader.SKIP_FRAMES);

                for (int i = 0; i < node.methods.size(); i++) {
                    MethodNode methodNode = node.methods.get(i);
                    // The adapter itself holds the inlined code; no delegate is needed because the
                    // original node is replaced below.
                    JSRInlinerAdapter adapter = new JSRInlinerAdapter(
                            null,
                            methodNode.access,
                            methodNode.name,
                            methodNode.desc,
                            methodNode.signature,
                            methodNode.exceptions.toArray(new String[0]));
                    methodNode.accept(adapter);
                    node.methods.set(i, adapter);
                }

                classes.put(node.name, node);
                setConstantPool(node, new ConstantPool(reader));
            } catch (IllegalArgumentException | IndexOutOfBoundsException x) {
                System.err.println("Failed to parse class entry " + name);
                x.printStackTrace();
            }
        }
    }

    /**
     * Heuristic Allatori decryptor check. The method must call {@code String.charAt(I)C} and
     * {@code String.length()I}, construct a String, contain more than two {@code IXOR}
     * instructions, and contain at least one {@code NEWARRAY}.
     */
    private static boolean isAllatoriMethod(MethodNode decryptorNode) {
        return containsInvokeVirtual(decryptorNode, "java/lang/String", "charAt", "(I)C")
                && containsInvokeVirtual(decryptorNode, "java/lang/String", "length", "()I")
                && containsInvokeSpecial(decryptorNode, "java/lang/String", "<init>", null)
                && countOccurencesOf(decryptorNode, IXOR) > 2
                && countOccurencesOf(decryptorNode, NEWARRAY) > 0;
    }

    /** True if any instruction of {@code methodNode} matches {@link #isInvokeVirtual}. */
    public static boolean containsInvokeVirtual(MethodNode methodNode, String owner, String name, String desc) {
        for (AbstractInsnNode insn : methodNode.instructions) {
            if (isInvokeVirtual(insn, owner, name, desc)) {
                return true;
            }
        }
        return false;
    }

    /** True if {@code insn} is an INVOKEVIRTUAL matching every non-null of owner/name/desc. */
    public static boolean isInvokeVirtual(AbstractInsnNode insn, String owner, String name, String desc) {
        if (insn == null) {
            return false;
        }
        if (insn.getOpcode() != INVOKEVIRTUAL) {
            return false;
        }
        MethodInsnNode methodInsnNode = (MethodInsnNode) insn;
        return (owner == null || methodInsnNode.owner.equals(owner)) &&
                (name == null || methodInsnNode.name.equals(name)) &&
                (desc == null || methodInsnNode.desc.equals(desc));
    }

    /** True if any instruction of {@code methodNode} matches {@link #isInvokeSpecial}. */
    public static boolean containsInvokeSpecial(MethodNode methodNode, String owner, String name, String desc) {
        for (AbstractInsnNode insn : methodNode.instructions) {
            if (isInvokeSpecial(insn, owner, name, desc)) {
                return true;
            }
        }
        return false;
    }

    /** True if {@code insn} is an INVOKESPECIAL matching every non-null of owner/name/desc. */
    public static boolean isInvokeSpecial(AbstractInsnNode insn, String owner, String name, String desc) {
        if (insn == null) {
            return false;
        }
        if (insn.getOpcode() != INVOKESPECIAL) {
            return false;
        }
        MethodInsnNode methodInsnNode = (MethodInsnNode) insn;
        return (owner == null || methodInsnNode.owner.equals(owner)) &&
                (name == null || methodInsnNode.name.equals(name)) &&
                (desc == null || methodInsnNode.desc.equals(desc));
    }

    /** Number of instructions in {@code methodNode} with the given opcode. */
    public static int countOccurencesOf(MethodNode methodNode, int opcode) {
        int i = 0;
        for (AbstractInsnNode insnNode : methodNode.instructions) {
            if (insnNode.getOpcode() == opcode) {
                i++;
            }
        }
        return i;
    }

    /** Constant pool recorded for {@code classNode}, or {@code null} if none was registered. */
    public static ConstantPool getConstantPool(ClassNode classNode) {
        return constantPools.get(classNode);
    }

    /** Records the constant pool of {@code owner}. */
    public static void setConstantPool(ClassNode owner, ConstantPool pool) {
        constantPools.put(owner, pool);
    }

    /** Live view of all loaded classes. */
    public static Collection<ClassNode> classNodes() {
        return classes.values();
    }

    /**
     * Removes decryptor methods that are no longer called. Any candidate still targeted by an
     * {@code INVOKESTATIC} anywhere (e.g. a call site whose decryption failed) is kept.
     *
     * @param toRemove candidate methods; consumed by this call
     * @return number of methods removed
     */
    private static int cleanup(Set<MethodNode> toRemove) {
        for (ClassNode node : classNodes()) {
            for (MethodNode methodNode : node.methods) {
                for (AbstractInsnNode insn = methodNode.instructions.getFirst(); insn != null; insn = insn.getNext()) {
                    if (insn.getOpcode() != INVOKESTATIC) {
                        continue;
                    }
                    MethodInsnNode call = (MethodInsnNode) insn;
                    ClassNode owner = classes.get(call.owner);
                    if (owner != null) {
                        toRemove.remove(findMethod(owner, call.name, call.desc));
                    }
                }
            }
        }
        int count = 0;
        for (ClassNode classNode : classNodes()) {
            for (Iterator<MethodNode> it = classNode.methods.iterator(); it.hasNext(); ) {
                if (toRemove.remove(it.next())) {
                    it.remove();
                    count++;
                }
            }
        }
        return count;
    }

}
